[core] Unify validation_step_outputs to always return list-of-lists #15470

Open
XuesongYang wants to merge 4 commits into NVIDIA-NeMo:main from XuesongYang:xueyang/pr-unify-multi-dataloader-modelPT

Conversation

@XuesongYang (Collaborator) commented Mar 6, 2026

What does this PR do ?

Unify ModelPT.validation_step_outputs (and test_step_outputs) to always return a list of lists, so a single dataloader is simply the N=1 case and subclasses no longer need to branch on the output shape. Normalize both _validation_dl and _test_dl to Optional[List[DataLoader]] via their respective resolvers.
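The normalization described above could look roughly like the following. This is only a sketch: `normalize_dataloaders` and `FakeLoader` are hypothetical names used for illustration (the real logic lives in resolve_validation_dataloaders / resolve_test_dataloaders), and a plain class stands in for DataLoader so the example runs without torch.

```python
from typing import Any, List, Optional


class FakeLoader:
    """Stand-in for torch.utils.data.DataLoader so the sketch runs without torch."""


def normalize_dataloaders(dl: Any) -> Optional[List[Any]]:
    """Hypothetical helper (not a NeMo function): always return None or a list."""
    if dl is None:
        return None  # nothing configured
    if isinstance(dl, (list, tuple)):
        return list(dl)  # already a collection of dataloaders
    return [dl]  # a bare dataloader becomes the N=1 case


single = FakeLoader()
pair = [FakeLoader(), FakeLoader()]

print(normalize_dataloaders(single) == [single])  # True
print(len(normalize_dataloaders(pair)))           # 2
print(normalize_dataloaders(None))                # None
```

With this shape, downstream code can call len() on the result whenever it is not None, without checking whether it received a bare loader or a list.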

Collection: Core, ASR, TTS, Audio

Changelog

  • modelPT.py: validation_step_outputs / test_step_outputs properties always return [[] for _ in range(num_dl)]; on_validation_epoch_end / on_test_epoch_end use len() == 1 instead of isinstance(..., dict) for single-vs-multi dispatch; empty-output guard updated to all(len(o) == 0 for o in ...) since [[]] is truthy; empty dataloader buckets skipped in multi-DL loop
  • model_utils.py: resolve_validation_dataloaders and resolve_test_dataloaders wrap bare DataLoader into [DataLoader] at both single-value paths, normalizing _validation_dl and _test_dl to Optional[List[DataLoader]]
  • modelPT.py (setup_multiple_validation_data): type annotation updated; isinstance guard simplified to truthiness check after normalization
  • 15 model files (ASR, TTS G2P, Audio): remove if/else branching in validation_step / test_step; always use self.validation_step_outputs[dataloader_idx].append(...)
  • transformer_bpe_models.py: remove isinstance(outputs[0], dict) normalization loop in multi_validation_epoch_end — base class now iterates dataloaders and calls it once per DL
  • audio_to_audio.py: simplify _get_num_dataloaders (both val and test) and logging callback setup after normalization
  • fastpitch.py, magpietts.py, magpietts_preference_optimization.py: add RuntimeError guard for len(validation_step_outputs) != 1; add early-return on empty outputs; use self.validation_step_outputs[0] consistently
  • ssl_models.py: fix EncDecMaskedTokenPredModel.test_step — was appending to validation_step_outputs instead of test_step_outputs
  • Test models (test_ema.py, check_for_ranks.py, test_ptl_stateless_timer.py): override multi_validation_epoch_end instead of on_validation_epoch_end; base class handles iteration, clearing, and per-DL prefix
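To illustrate why the empty-output guard had to change, here is a toy stand-in for the output cache. It is an assumption-laden sketch, not the ModelPT code: the class name `OutputsCache`, the lazy caching, and the method `all_outputs_empty` are hypothetical; only the `[[] for _ in range(num_dl)]` shape and the `all(len(o) == 0 ...)` guard come from the changelog.

```python
from typing import Any, List, Optional


class OutputsCache:
    """Toy stand-in for the ModelPT output cache (illustrative only)."""

    def __init__(self, validation_dl: Optional[List[Any]]):
        self._validation_dl = validation_dl
        self._validation_step_outputs: Optional[List[List[Any]]] = None

    @property
    def validation_step_outputs(self) -> List[List[Any]]:
        # Build the cache lazily so that appends persist across accesses.
        if self._validation_step_outputs is None:
            num_dl = len(self._validation_dl) if self._validation_dl else 1
            self._validation_step_outputs = [[] for _ in range(num_dl)]
        return self._validation_step_outputs

    def all_outputs_empty(self) -> bool:
        # `[[]]` is truthy, so `not outputs` would miss the empty case.
        return all(len(o) == 0 for o in self.validation_step_outputs)


cache = OutputsCache(validation_dl=["dl0"])  # one (fake) dataloader
print(bool(cache.validation_step_outputs))   # True, although nothing was appended
print(cache.all_outputs_empty())             # True
cache.validation_step_outputs[0].append({"val_loss": 0.5})
print(cache.all_outputs_empty())             # False
```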

Usage

No API changes for single-dataloader models — dataloader_idx=0 is the default. Subclasses should use the [dataloader_idx] indexing pattern:

# In validation_step:
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    metrics = self.compute_metrics(batch)
    self.validation_step_outputs[dataloader_idx].append(metrics)
    return metrics

# In multi_validation_epoch_end (preferred over on_validation_epoch_end):
def multi_validation_epoch_end(self, outputs, dataloader_idx=0):
    avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
    self.log("val_loss", avg_loss)
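
The division of labor between the base class and multi_validation_epoch_end can be sketched without torch or Lightning. The driver `run_epoch_end` below is hypothetical and only models the behavior listed in the changelog (skip everything when all buckets are empty, skip empty buckets, call the hook once per dataloader, then clear); the single-vs-multi dispatch, logging, and per-DL prefixes of the real base class are not modelled.

```python
from typing import Any, Callable, Dict, List

Bucket = List[Dict[str, float]]

averages: Dict[int, float] = {}


def multi_validation_epoch_end(outputs: Bucket, dataloader_idx: int = 0) -> None:
    # Plain-float stand-in for torch.stack(...).mean().
    averages[dataloader_idx] = sum(x["val_loss"] for x in outputs) / len(outputs)


def run_epoch_end(outputs: List[Bucket], hook: Callable[[Bucket, int], Any]) -> List[int]:
    """Hypothetical driver: call `hook` once per non-empty dataloader bucket."""
    called: List[int] = []
    if all(len(bucket) == 0 for bucket in outputs):
        return called  # nothing was validated, so the hook is never called
    for dataloader_idx, bucket in enumerate(outputs):
        if len(bucket) == 0:
            continue  # skip empty dataloader buckets
        hook(bucket, dataloader_idx)
        called.append(dataloader_idx)
    for bucket in outputs:
        bucket.clear()  # reset the cache for the next epoch
    return called


outputs = [[{"val_loss": 1.0}, {"val_loss": 3.0}], [], [{"val_loss": 0.5}]]
called = run_epoch_end(outputs, multi_validation_epoch_end)
print(called)    # [0, 2]
print(averages)  # {0: 2.0, 2: 0.5}
```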

Copilot AI review requested due to automatic review settings March 6, 2026 02:11
github-actions bot added the core, TTS, ASR, and audio labels Mar 6, 2026
Contributor

Copilot AI left a comment

Pull request overview

This PR standardizes ModelPT.validation_step_outputs / test_step_outputs to use a consistent “list-of-lists” shape, simplifying subclass logic by removing single-vs-multi-dataloader branching and improving epoch-end dispatch/guards.

Changes:

  • Updated ModelPT epoch-end logic to dispatch based on len(outputs) (single vs multi dataloader) and to skip/guard empty per-dataloader outputs.
  • Normalized validation dataloader storage to List[DataLoader] in resolve_validation_dataloaders() and refactored many model validation_step/test_step implementations to always append via [dataloader_idx].
  • Updated unit tests and added a regression test to ensure multi_validation_epoch_end / multi_test_epoch_end are not called when all outputs are empty.

Reviewed changes

Copilot reviewed 22 out of 22 changed files in this pull request and generated 3 comments.

| File | Description |
| --- | --- |
| tests/core_ptl/test_ptl_stateless_timer.py | Updates test model hooks to the new list-of-lists output shape and adds an empty-epoch regression test. |
| tests/core_ptl/check_for_ranks.py | Switches test model to append outputs via validation_step_outputs[dataloader_idx] and uses multi_validation_epoch_end. |
| tests/collections/common/test_ema.py | Updates validation/test steps to append via [dataloader_idx] and uses multi_validation_epoch_end. |
| nemo/utils/model_utils.py | Wraps single validation dataloaders into a list to normalize _validation_dl shape. |
| nemo/core/classes/modelPT.py | Implements the unified list-of-lists output cache, updates epoch-end dispatch and empty-output guards. |
| nemo/collections/tts/models/magpietts_preference_optimization.py | Removes single-vs-multi branching in validation output accumulation; adjusts epoch-end logic for the new shape. |
| nemo/collections/tts/models/magpietts.py | Updates validation accumulation and epoch-end collection to use validation_step_outputs[0] consistently. |
| nemo/collections/tts/models/fastpitch.py | Updates validation accumulation and epoch-end processing to use the new output structure. |
| nemo/collections/tts/g2p/models/t5.py | Removes branching on dataloader count; always appends via [dataloader_idx]. |
| nemo/collections/tts/g2p/models/ctc.py | Removes branching on dataloader count; always appends via [dataloader_idx]. |
| nemo/collections/audio/models/audio_to_audio.py | Removes branching on dataloader count; simplifies callback setup in line with _validation_dl normalization. |
| nemo/collections/asr/models/transformer_bpe_models.py | Simplifies multi-epoch-end logic to assume per-dataloader outputs (base class iterates dataloaders). |
| nemo/collections/asr/models/ssl_models.py | Removes branching on dataloader count and fixes test_step to append to test_step_outputs. |
| nemo/collections/asr/models/sortformer_diar_models.py | Removes branching on dataloader count; always appends via [dataloader_idx]. |
| nemo/collections/asr/models/slu_models.py | Removes branching on dataloader count; always appends via [dataloader_idx]. |
| nemo/collections/asr/models/rnnt_models.py | Removes branching on dataloader count; always appends via [dataloader_idx]. |
| nemo/collections/asr/models/label_models.py | Removes branching on dataloader count; always appends via [dataloader_idx]. |
| nemo/collections/asr/models/hybrid_rnnt_ctc_models.py | Removes branching on dataloader count; always appends via [dataloader_idx]. |
| nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models_prompt.py | Removes branching on dataloader count; always appends via [dataloader_idx]. |
| nemo/collections/asr/models/ctc_models.py | Removes branching on dataloader count; always appends via [dataloader_idx]. |
| nemo/collections/asr/models/classification_models.py | Removes branching on dataloader count; always appends via [dataloader_idx]. |
| nemo/collections/asr/models/aed_multitask_models.py | Removes branching on dataloader count; always appends via [dataloader_idx]. |

Comment on lines 63 to 68
  def _get_num_dataloaders(self, tag: str = 'val'):
      if tag == 'val':
-         num_dataloaders = len(self._validation_dl) if isinstance(self._validation_dl, List) else 1
+         num_dataloaders = len(self._validation_dl) if self._validation_dl else 1
      elif tag == 'test':
-         num_dataloaders = len(self._test_dl) if isinstance(self._test_dl, List) else 1
+         num_dataloaders = len(self._test_dl) if self._test_dl else 1
      else:

Copilot AI Mar 6, 2026

_get_num_dataloaders() now returns 1 when _validation_dl is an empty list. This changes the meaning from “number of configured dataloaders” to “at least 1”, which can cause _setup_metrics() to initialize metrics for a non-existent dataloader. Also, isinstance(self._test_dl, List) uses typing.List, which raises TypeError at runtime for isinstance checks; this should be replaced with a runtime type like (list, tuple) (and likely the same empty-list handling as for validation).
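
The two behaviors raised in this comment can be checked in isolation. The helper names below are hypothetical, and plain lists stand in for the dataloader attributes. One caveat worth noting: in standard Python, `isinstance` accepts the bare `typing.List` alias, and only a subscripted form such as `List[int]` raises `TypeError`.

```python
from typing import List, Optional


def count_by_truthiness(dl: Optional[list]) -> int:
    # Mirrors `len(dl) if dl else 1`: an empty list falls through to 1.
    return len(dl) if dl else 1


def count_by_none_check(dl: Optional[list]) -> int:
    # Mirrors `len(dl) if dl is not None else 1`: an empty list yields 0.
    return len(dl) if dl is not None else 1


print(count_by_truthiness([]), count_by_none_check([]))          # 1 0
print(count_by_truthiness(None), count_by_none_check(None))      # 1 1
print(count_by_truthiness(["a", "b"]), count_by_none_check(["a", "b"]))  # 2 2

# The bare typing.List alias works with isinstance ...
bare_alias_ok = isinstance([], List)

# ... whereas a subscripted generic is rejected at runtime.
try:
    isinstance([], List[int])
    subscripted_raises = False
except TypeError:
    subscripted_raises = True

print(bare_alias_ok, subscripted_raises)  # True True
```

The two counting helpers only disagree for an empty list, which is the case under discussion.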

validation_step_outputs and test_step_outputs now always return a list
of lists (one inner list per dataloader), eliminating if/else branching
in every subclass that handles single-vs-multi dataloader shapes.

- validation_step_outputs property: returns [[] for _ in range(num_dl)]
- on_validation/test_epoch_end: len()==1 dispatch, all(len(o)==0 ...)
  empty guard, skip empty DL buckets in multi-DL loop
- Normalize _validation_dl to Optional[List[DataLoader]] in resolver
- 15 model files: self.validation_step_outputs[dataloader_idx].append()
- TTS models: RuntimeError guard for single-DL assumption
- Test models: override multi_validation_epoch_end, not on_*_epoch_end
- Bug fix: ssl_models test_step appended to wrong outputs list
- New test: empty outputs skip multi_epoch_end

Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Made-with: Cursor
Same wrapping pattern as _validation_dl: wrap bare DataLoader into
[DataLoader] at both single-value paths in resolve_test_dataloaders.
Simplify isinstance guards in test_step_outputs and _get_num_dataloaders.

Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
XuesongYang force-pushed the xueyang/pr-unify-multi-dataloader-modelPT branch from 75fc50d to 0194224 March 9, 2026 22:27
  def _get_num_dataloaders(self, tag: str = 'val'):
      if tag == 'val':
-         num_dataloaders = len(self._validation_dl) if isinstance(self._validation_dl, List) else 1
+         num_dataloaders = len(self._validation_dl) if self._validation_dl else 1
Collaborator

technically I think you want

Suggested change
- num_dataloaders = len(self._validation_dl) if self._validation_dl else 1
+ num_dataloaders = len(self._validation_dl) if self._validation_dl is not None else 1

and similar everywhere else - make it clear this is a None check

Collaborator Author

The truthiness check (if self._validation_dl) and the explicit check (if self._validation_dl is not None) are equivalent here, because after resolver normalization self._validation_dl is always either None or a non-empty List[DataLoader]. The None case is converted to [] by val_dataloader() (PTL 2.0+ doesn't accept None), but PTL then skips validation entirely for an empty list, so neither on_validation_start nor _get_num_dataloaders ever executes with [].

Collaborator

Why are we setting num_dataloaders = 1 when the list is empty then?

Collaborator

@pzelasko pzelasko left a comment

Looks good in general, but before I approve: which models did you test it on, and did you test with both single and multiple validation datasets?

Member

@nithinraok nithinraok left a comment

LGTM in general. How are you validating the changes? CI only runs on a single GPU; have you tested on multiple GPUs with multiple nodes?

@XuesongYang
Collaborator Author

@pzelasko @nithinraok all of your questions about per-model testing are valid.

So far:

  • I only tested by training MagpieTTS on a multi-node setup, and it worked as expected.
  • CI covers single-GPU runs across the changed ASR, TTS, and Audio model files.

If we want to proceed with this change, we may need to test EncDecCTCModel, EncDecRNNTModel, FastPitch, AudioToAudio, ..., with multiple validation dataloaders in multi-GPU integration tests. Do we have such resources/tests in CI?

@pzelasko
Collaborator

I'm afraid we don't have CI tests for these conditions. I think we can use up to 2 GPUs in CI currently (cc @chtruong814 to verify). It might be a good idea to add at least one test that runs a multi-GPU, multi-validation training of some representative model for ~10 training steps with 1 validation run, and parses the CLI output to verify it's OK (Magpie is probably OK, or maybe Parakeet as it's the most popular?).
